package org.hepeng.workx.mybatis.interceptor;

import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.util.deparser.ExpressionDeParser;
import net.sf.jsqlparser.util.deparser.SelectDeParser;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.collections.CollectionUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.property.PropertyTokenizer;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.hepeng.workx.extension.XLoader;
import org.hepeng.workx.mybatis.MybatisConstant;
import org.hepeng.workx.mybatis.util.WorkXMybatisEnvironment;
import org.hepeng.workx.sqlparse.CountItem;
import org.hepeng.workx.sqlparse.SelectParser;
import org.hepeng.workx.util.PageQuery;
import org.hepeng.workx.util.PageQuerys;
import org.hepeng.workx.util.PageResultSet;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * MyBatis {@link Executor} plugin that turns a {@code query} call into a paginated query whenever a
 * {@link PageQuery} is pending in {@link PageQuerys}.
 *
 * <p>For such a call it first runs a count query derived from the original SQL. If the count is
 * positive, it then runs a page query produced by the dialect's {@link PageQuerySQLRewriter}. The
 * result is returned as a {@link PageList} holding the page's records and a {@link PageResultSet}
 * that describes the page.
 *
 * @author he peng
 */
@Intercepts({
        @Signature(type = Executor.class, method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class,
                        ResultHandler.class, CacheKey.class, BoundSql.class}),
        @Signature(type = Executor.class, method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class,
                        ResultHandler.class})
})
public class PageQueryInterceptor extends AbstractExecutorQueryInvokeInterceptor {

    /** Id suffix of the derived statement that executes the count query. */
    private static final String COUNT_ID_SUFFIX = "_COUNT";

    /** Id suffix of the derived statement that executes the paginated query. */
    private static final String PAGE_QUERY_ID_SUFFIX = "_PageQuery";

    /** Alias of the derived table used when the count query has to wrap the original SQL. */
    private static final String COUNT_WRAPPER_ALIAS = "tmp_count";

    private static final XLoader<PageQuerySQLRewriter> XLOADER = XLoader.getXLoader(PageQuerySQLRewriter.class , MybatisConstant.MYBATIS_X_POINT_DIRECTORY);

    /**
     * Executes the intercepted query as a paginated query when a {@link PageQuery} is pending.
     *
     * <p>Without a pending page request the invocation proceeds untouched. Otherwise a count query is
     * executed first. Only when it reports at least one row is the dialect-specific page query run.
     * The returned {@link PageList} always carries its {@link PageResultSet}, which holds the total
     * row count and the properties copied from the page request. This is also true when the count is
     * zero; the records are then an empty list. The pending page request is cleared in every case,
     * including on failure.
     *
     * @param invocation the intercepted {@code Executor.query} call
     * @return the original query result when no page request is pending, otherwise a {@link PageList}
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        PageQuery pageQuery = PageQuerys.getPageQuery();
        if (Objects.isNull(pageQuery)) {
            return invocation.proceed();
        }

        PageResultSet<Object> pageResultSet = new PageResultSet<>();
        PageList<Object> pageList = new PageList<>();
        try {
            String originalSql = extractSql(invocation);
            String countSql = rewriteSqlToCount(originalSql);
            Long count = executeCountSql(invocation, countSql);
            BeanUtils.copyProperties(pageResultSet, pageQuery);
            pageResultSet.setTotalRow(count);

            List<Object> records = new ArrayList<>(0);
            if (count > 0) {
                PageQuerySQLRewriter pageQuerySQLRewriter = XLOADER.getX(WorkXMybatisEnvironment.getDialect(null));
                String pageQuerySql = pageQuerySQLRewriter.rewrite(originalSql, pageQuery.getStartRow(), pageQuery.getPageSize());
                records = executePageQuerySql(invocation, pageQuerySql);
                pageList.addAll(records);
            }
            pageResultSet.setRecords(records);
            // Attach the result set for empty pages too, so callers can always read the total row count.
            pageList.setPageResultSet(pageResultSet);
        } finally {
            PageQuerys.clear();
        }

        return pageList;
    }

    /**
     * Runs the already-paginated SQL with the intercepted statement's result maps.
     *
     * <p>The caller's {@link ResultHandler} is passed through unchanged.
     */
    private List<Object> executePageQuerySql(Invocation invocation , String pageQuerySql) throws SQLException {
        List<ResultMap> resultMaps = extractMappedStatement(invocation).getResultMaps();
        return executeSql(PAGE_QUERY_ID_SUFFIX, invocation, pageQuerySql, resultMaps,
                extractResultHandler(invocation));
    }

    /**
     * Executes {@code sql} through the intercepted executor under a copy of the intercepted statement.
     *
     * <p>The copy is made with {@code copyFromPrototype}, using {@code idSuffix} and the given result
     * maps. It keeps the original parameter object and parameter mappings. {@link RowBounds#DEFAULT}
     * is used because {@code sql} already limits or counts the rows itself. Applying the caller's
     * row bounds on top would skip rows a second time.
     *
     * @param idSuffix      suffix handed to {@code copyFromPrototype} for the derived statement
     * @param invocation    the intercepted call supplying executor, statement, SQL binding and parameters
     * @param sql           the SQL text to execute
     * @param resultMaps    result maps of the derived statement
     * @param resultHandler handler for result rows; {@code null} to collect them into the returned list
     * @return the rows returned by the executor
     */
    private List<Object> executeSql(String idSuffix , Invocation invocation , String sql ,
                                    List<ResultMap> resultMaps , ResultHandler resultHandler) throws SQLException {
        Executor executor = extractExecutor(invocation);
        MappedStatement newMs = copyFromPrototype(idSuffix, extractMappedStatement(invocation), resultMaps);
        BoundSql boundSql = extractBoundSql(invocation);
        Object parameterObject = extractParameterObject(invocation);
        BoundSql newBoundSql = new BoundSql(newMs.getConfiguration(), sql,
                boundSql.getParameterMappings(), parameterObject);
        // Must happen before createCacheKey, which already resolves every parameter mapping.
        copyAdditionalParameters(boundSql, newBoundSql);
        CacheKey cacheKey = executor.createCacheKey(newMs, parameterObject, RowBounds.DEFAULT, newBoundSql);
        return executor.query(newMs, parameterObject, RowBounds.DEFAULT, resultHandler, cacheKey, newBoundSql);
    }

    /**
     * Copies the additional parameters referenced by {@code source}'s parameter mappings into {@code target}.
     *
     * <p>Dynamic SQL such as {@code <foreach>} binds generated values (for example
     * {@code __frch_item_0}) on the original {@link BoundSql}, not on the parameter object. A newly
     * constructed {@link BoundSql} lacks them, so those mappings could not be resolved. The root
     * property name is copied, so nested mappings like {@code __frch_item_0.id} keep the whole bound
     * object.
     */
    private static void copyAdditionalParameters(BoundSql source, BoundSql target) {
        for (ParameterMapping mapping : source.getParameterMappings()) {
            String rootName = new PropertyTokenizer(mapping.getProperty()).getName();
            if (source.hasAdditionalParameter(rootName) && !target.hasAdditionalParameter(rootName)) {
                target.setAdditionalParameter(rootName, source.getAdditionalParameter(rootName));
            }
        }
    }

    /**
     * Executes {@code countSql} with the intercepted statement's parameters and returns the count.
     *
     * <p>The query runs under a derived statement whose single result map has type {@link Long}. No
     * result handler is used. A caller-supplied handler would otherwise consume the count row and
     * leave the returned list empty.
     *
     * @param invocation the intercepted call
     * @param countSql   SQL returning the count in the first column of the first row
     * @return the count, or {@code 0} when the query returns no row or a {@code null} value
     * @throws SQLException if the executor fails
     */
    protected Long executeCountSql(Invocation invocation , String countSql) throws SQLException {
        MappedStatement ms = extractMappedStatement(invocation);
        List<ResultMap> resultMaps = new ArrayList<>(1);
        resultMaps.add(new ResultMap.Builder(ms.getConfiguration(), ms.getId(),
                Long.class, new ArrayList<>(0)).build());
        List<Object> results = executeSql(COUNT_ID_SUFFIX, invocation, countSql, resultMaps,
                Executor.NO_RESULT_HANDLER);
        if (CollectionUtils.isEmpty(results) || Objects.isNull(results.get(0))) {
            return 0L;
        }
        // Defensive: accept any numeric value rather than insisting on exactly java.lang.Long.
        return ((Number) results.get(0)).longValue();
    }

    /**
     * Rewrites a SELECT statement into one that counts the rows it would return.
     *
     * <p>For a plain SELECT without WITH items and without DISTINCT, the select list is replaced by a
     * {@link CountItem} and the ORDER BY clause is dropped. The statement is then rendered through
     * {@link SelectParser}.
     *
     * <p>Other statements are wrapped as {@code select count(*) from (<sql>) tmp_count} instead. This
     * covers set operations such as UNION, statements with WITH items, and SELECT DISTINCT. For
     * DISTINCT, the ORDER BY clause is dropped before wrapping.
     *
     * <p>NOTE(review): a GROUP BY query still goes through the plain rewrite. That yields one count per
     * group, and {@link #executeCountSql} only reads the first row. Such queries probably need the
     * wrapping path as well — confirm with the JSqlParser version in use (its group-by accessor
     * differs between versions).
     *
     * @param originalSql the SQL of the intercepted query
     * @return SQL returning the row count of {@code originalSql}
     * @throws JSQLParserException if {@code originalSql} cannot be parsed
     */
    protected String rewriteSqlToCount(String originalSql) throws JSQLParserException {
        Select select = (Select) CCJSqlParserUtil.parse(originalSql);
        if (!(select.getSelectBody() instanceof PlainSelect)
                || CollectionUtils.isNotEmpty(select.getWithItemsList())) {
            // Set-operation bodies are not PlainSelects, and rendering only the body would drop WITH items.
            return wrapAsCountSql(originalSql);
        }

        PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
        plainSelect.setOrderByElements(null);
        if (Objects.nonNull(plainSelect.getDistinct())) {
            // Replacing a DISTINCT select list with a count would count all rows, not the distinct ones.
            return wrapAsCountSql(plainSelect.toString());
        }

        List<SelectItem> selectItems = new ArrayList<>(1);
        selectItems.add(new CountItem());
        plainSelect.setSelectItems(selectItems);

        StringBuilder countSqlBuf = new StringBuilder();
        ExpressionDeParser expressionDeParser = new ExpressionDeParser();
        SelectDeParser selectDeParser = new SelectParser(expressionDeParser, countSqlBuf);
        expressionDeParser.setSelectVisitor(selectDeParser);
        expressionDeParser.setBuffer(countSqlBuf);
        plainSelect.accept(selectDeParser);
        return countSqlBuf.toString();
    }

    /** Wraps {@code sql} in a derived table and counts its rows. */
    private static String wrapAsCountSql(String sql) {
        return "select count(*) from (" + sql + ") " + COUNT_WRAPPER_ALIAS;
    }
}
